%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ks_het_firms_real_only.m
% Solves a GE heterogeneous firms investment model with macro shocks and real adjustment costs 
% via the Krusell Smith algorithm
% Ivan Alfaro, Nick Bloom and Xiaoji Lin 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

close all; clearvars; clc;

% Optionally record all console output of this run in a text log file.
diary_record = 1;
if (diary_record)
    % close any diary left open by a previous (possibly crashed) run so the
    % log file is not held open while we try to delete it
    diary off;
    % delete the old log only if it exists (avoids a "file not found" warning)
    if exist('ks_het_firms.txt','file') == 2
        delete('ks_het_firms.txt');
    end
    diary('ks_het_firms.txt');
    diary on
end

%%%%initial program setup
%start tracking timing
tic; 
disp('%%%%%%%%%%%%%%%%%%%')
disp('Solving the het firms model with macro shocks via Krusell Smith')
disp(' ')

%%%% Structural parameters
beta  = 1/1.0125;  % household subjective discount factor
R     = 1/beta;    % implied real interest rate in GE
delta = 0.05;      % depreciation rate
alpha = 0.70;      % capital share
c_k   = 0.20;      % adjustment cost parameter

%%%% Output-price grid and market-clearing controls
% p0 spans [plb, pub]; plb and pub also serve as the initial search bracket
plb   = 0.7;       % price lower boundary
pub   = 1.8;       % price upper boundary
pnum  = 25;        % number of price grid points
p0    = linspace(plb, pub, pnum);

perrortol = 1e-4;  % tolerance on the market-clearing error
disttol   = 1e-4;  % distribution cells with mass at or below this are skipped
pwindow   = 0.05;  % width unit of the bracket used when the price search restarts
pcutoff   = 15;    % price-search iteration at which the bracket is reset around the forecast

%%%%%%%%%%%%
%%%% Capital grids and numerical-solution settings
%%%%%%%%%%%%

% Micro (firm) capital grid, log-spaced. The log step equals -log(1-delta),
% so adjacent points satisfy k0(i-1) = (1-delta)*k0(i) up to rounding:
% a firm that does not invest moves down exactly one grid point.
knum = 210;                                          % number of micro capital points
kmin = 0.2;                                          % micro min
kmax = exp(log(kmin) - (knum-1)*log(1.0-delta));     % micro max implied by the step
k0   = exp(linspace(log(kmin), log(kmax), knum))';   % column vector
knum = length(k0);

% Macro (aggregate) capital grid, log-spaced on [Kmin, Kmax]
Knum = 20;   % macro capital grid points
Kmin = 100;  % macro min
Kmax = 450;  % macro max
K0   = exp(linspace(log(Kmin), log(Kmax), Knum))';

% Iteration limits and tolerances
maxfcstit   = 50;    % max number of iterations on the fcst rule
fcstdamp    = 0.05;  % fcst rule updates 100*fcstdamp % of the way
maxvfit     = 1000;  % max number of iterations on the value function
maxvferr    = 1e-3;  % max abs change in the value function at convergence
howardnum   = 30;    % number of Howard acceleration steps within each VFI step
maxdistit   = 1000;  % max number of distributional iterations
maxdisterr  = 1e-7;  % max abs change in the stationary distribution
maxbunch    = 1e-3;  % max bunching at grid endpoints
maxpit      = 100;   % max iterations to find the equilibrium price
maxperr     = 1e-4;  % max error in the clearing price for simulations
checkbounds = 1;     % 1 = report mass at the capital-grid endpoints after simulating
ZeroBoundary = 1.e-4;          % |investment| at or below this is snapped to exactly zero
fcsterrortol = 0.1;            % convergence tolerance for the forecast rules
RMSEchangetol = 0.001;         % convergence tolerance for RMSE changes
R2changetol   = 0.01;          % convergence tolerance for R2 changes
maxDenHaanchangetol = 0.001;   % convergence tolerance for the Den Haan max statistic
avgDenHaanchangetol = 0.001;   % convergence tolerance for the Den Haan avg statistic

% Per-iteration storage of forecast-rule coefficient errors
kfcsterrorstore = zeros(maxfcstit,1);
pfcsterrorstore = zeros(maxfcstit,1);
fcsterrstore    = zeros(maxfcstit,1);

% GE convergence criterion selector:
%   1 = coefficients
%   2 = RMSE in changes
%   3 = R^2 in changes
%   4 = max Den Haan statistic in changes
%   5 = avg Den Haan statistic in changes
GEerrorswitch = 4;
 
%%%%%%%%%%%%
%%%%This block sets up the grids and productivity and uncertainty transition matrices
%%%%%%%%%%%%

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% micro time-varying uncertainty
snum        = 2;                                  % number of firm specific uncertainty states
sigmaz_l    = 0.051;                              % low firm-specific sigma from Bloom et al 2018 RUBC
sigmaz      = [sigmaz_l sigmaz_l*4.1];            % low and high firm-specific sigma from Bloom et al 2018 RUBC
% macro time-varying uncertainty
Snum        = 2;                                  % number of aggregate uncertainty states
sigmaA_l    = 0.0067;                             % low aggregate sigma from Bloom et al 2018 RUBC
sigmaA      = [sigmaA_l sigmaA_l*1.6];            % low and high aggregate sigma from Bloom et al 2018 RUBC
PiSigma     = ([1-0.026 0.026 ;                   % transition matrix from Bloom et al RUBC 2018 same for aggregate productivity
                  1-0.943 0.943;]);
% initial index into the Snum aggregate uncertainty states for the macro
% simulation; it must be derived from Snum (the dimension it indexes),
% not from the micro count snum
sigmaAinit  = ceil(Snum/2);
Pisigma     = PiSigma;                            % transition matrix from Bloom et al RUBC 2018 same for firm-specific productivity

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Firm-specific productivity z: log z has mean zbar and persistence rhoz and
% is discretized on znum points by TransitProb (presumably a Tauchen-style
% routine -- NOTE(review): confirm). Piz is indexed Piz(z,z',s), i.e. one
% transition matrix per micro uncertainty state.
znum    = 5;                                  % number of grid points
intit   = ceil(znum/2);                       % index of the median z
zbar    = 0;                                  % mean level of (log) productivity
rhoz    = 0.95;                               % persistence
stdz    = sqrt(sigmaz_l^2/(1-rhoz^2));        % unconditional std at the low sigma
nsigmaz = 5;                                  % controls the boundary of the z grid
Jen     = 2;                                  % parameter on Jensen adjustment
[z0, Piz] = TransitProb(rhoz,zbar,stdz,sigmaz,znum,snum,nsigmaz,Jen);
z0      = exp(z0)';                           % grid in levels
zmin    = min(z0);
zmax    = max(z0);

%set up macro productivity grid
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Aggregate productivity A: same construction as z, driven by the aggregate
% uncertainty levels sigmaA; PiA is indexed PiA(A,A',S).
Anum    = 5;                                  % number of aggregate productivity grid points
Abar    = -1;                                 % mean level of (log) productivity
rhoA    = 0.95;                               % persistence
stdA    = sqrt(sigmaA_l^2/(1-rhoA^2));        % unconditional std at the low sigma
nsigmaA = 1.75;                               % controls the boundary of the A grid
[A0, PiA] = TransitProb(rhoA,Abar,stdA,sigmaA,Anum,Snum,nsigmaA,Jen);
A0      = exp(A0)';                           % grid in levels
Amin    = min(A0);
Amax    = max(A0);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%set up simulation
Tsim = 5000; %how many periods to simulate
Terg = 1000; %how many periods to discard to remove influence of initialization
Ttot = Tsim + Terg; %implied total number of periods

% Starting grid indices for the simulations: the lower-middle point of each
% grid. floor(n/2) is 0 for a single-point grid, which is not a valid MATLAB
% index, so every start index is floored at 1.
kinit = max(1,floor(knum/2)); %starting point for micro capital simulation
zinit = max(1,floor(znum/2)); %starting point for micro productivity simulation
Kinit = max(1,floor(Knum/2)); %starting point for macro capital simulation
Ainit = max(1,floor(Anum/2)); %starting point for macro productivity simulation
Sinit = max(1,floor(Snum/2)); %starting point for macro sigma simulation
sinit = max(1,floor(snum/2)); %starting point for micro sigma simulation

randseed = 2501; %random number generator seed, for reproducibility
rng(randseed);

% %set up plotting parameters
lwidnum  = 2; %line width for graphs
fsizenum = 12; %font size for graphs

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%do the exogenous macro simulation outside of the forecast rule loop
Asimpos         = zeros(Ttot,1);
Asimpos(1)      = Ainit;
PisumA          = cumsum(PiA,2);
Ashocks         = rand(Ttot,1);
sigmaAshocks    = rand(Ttot,1);
PisumSigma      = cumsum(PiSigma,2);
sigmaAsimpos    = zeros(Ttot,1);
sigmaAsimpos(1) = sigmaAinit;

% Pin the last cumulative probability of every row to exactly 1. Rounding
% in cumsum can leave it slightly below 1; a uniform draw above that value
% would make find(...,1) return empty and break the assignments below.
PisumA(:,end,:)   = 1;
PisumSigma(:,end) = 1;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
for t = 2 : Ttot
     % time-varying uncertainty
     sct       = sigmaAsimpos(t-1);
     sprimect  = find(PisumSigma(sct,:) >= sigmaAshocks(t), 1); %compare today's uniform shock to transition matrix intervals
     sigmaAsimpos(t)= sprimect;        %aggregate volatility, store simulated value

     % agg productivity: A transition conditional on last period's sigma state
     Act       = Asimpos(t-1);
     Aprimect  = find(PisumA(Act,:,sct) >= Ashocks(t), 1);  %compare today's uniform shock to transition matrix intervals
     Asimpos(t) = Aprimect;    %store simulated value

 end

sigmaAsim = sigmaA(sigmaAsimpos);
%Asim      = A0(Asimpos);
Aggnum    = Anum*Knum*Snum;
statenum  = knum*snum*znum*Anum*Snum*Knum;

disp('Finished setting up the model.')
toc;
disp(' ')

%%%%%%%%%%%%
%%%%This block conducts the fcst rule iteration
%%%%%%%%%%%%

disp('Starting fcst rule iteration.')
toc;

% Initial guesses for the log-linear forecast rules, one (intercept, slope)
% pair per aggregate state; rows index A (Anum), columns index Sigma (Snum):
%   log K' = thetaK(A,Sigma,1) + thetaK(A,Sigma,2)*log K
%   log p  = thetap(A,Sigma,1) + thetap(A,Sigma,2)*log K
% thetaK and thetap will hold the re-estimated coefficients; they start at 0.
thetaKold        = zeros(Anum,Snum,2);
thetaKold(:,:,1) = [0.0367 0.0506
                    0.0435 0.0398
                    0.0333 0.0447
                    0.0419 0.0623
                    0.0674 0.0641];
thetaKold(:,:,2) = [0.9930 0.9907
                    0.9917 0.9927
                    0.9939 0.9919
                    0.9927 0.9888
                    0.9883 0.9888];
thetaK           = zeros(Anum,Snum,2);

thetapold        = zeros(Anum,Snum,2);
thetapold(:,:,1) = [2.7761 2.8880
                    2.8423 3.0371
                    3.0300 3.1786
                    3.1170 3.1637
                    2.9215 3.1863];
thetapold(:,:,2) = [-0.4954 -0.5230
                    -0.5063 -0.5482
                    -0.5415 -0.5726
                    -0.5594 -0.5692
                    -0.5265 -0.5734];
thetap           = zeros(Anum,Snum,2);

kprimeindold = zeros(Anum,Snum,Knum,znum,snum,knum); %(A,S,K,z,s,k) capital-policy index guess
kprimeindold(:,:,:,:) = kinit;   % the 4-subscript form assigns kinit to every element (trailing dims collapse into the 4th)
kprimeind    = kprimeindold;
kprimeold    = k0(kprimeindold); % policy levels; result takes the shape of the index array
kprime       = k0(kprimeind);

% Arrays on the (k',z,s,k,p) grid used by the price-dependent problem in the simulation
z1         = permute(repmat(z0,[1 snum knum  knum  pnum]),[4 1 2 3 5]); %(k',z,s,k,p) z varies along dim 2
k1         = permute(repmat(k0,[1 znum snum  knum  pnum]),[4 2 3 1 5]); %(k',z,s,k,p) k varies along dim 4
kprime1    = permute(repmat(k0,[1 znum snum  knum  pnum]),[1 2 3 4 5]); %(k',z,s,k,p) k' varies along dim 1
% NOTE(review): p1, kprimepold and z1p are not referenced in the visible part of
% the script -- confirm they are used further down before removing them
p1         = permute(repmat(p0',[1 znum snum knum ]),[2 3 4 1]); %(z,s,k,p)
%kprimep    = permute(repmat(k0,[1 znum snum pnum]),[2 3 1 4]); %(z,s,k,p);
kprimepold = permute(repmat(k0,[1 znum snum pnum]),[2 3 1 4]); %(z,s,k,p);
z1p        = repmat(z0,[1 snum knum pnum]); %(z,s,k,p);

p0mat      = permute(repmat(p0',[1 znum,snum,knum,knum]),[5 2 3 4 1]);%(k',z,s,k,p) p varies along dim 5
k1_z_alpha = z1.*k1.^alpha;                 %z*k^alpha on the (k',z,s,k,p) grid; multiplied by A in the simulation to give output

% 7-D arrays on (A,S,K,z,s,k,k') used to build the VFI return matrix
% NOTE(review): each holds Anum*Snum*Knum*znum*snum*knum^2 doubles (~88M with the
% current settings), so these dominate memory use
zmat      = permute(repmat(z0,[1 Anum Snum Knum snum knum knum]),[2 3 4 1 5 6 7]); %(A,S,K,z,s,k,k')
Amat      = repmat(A0,[1 Snum Knum znum snum knum knum]);                      %(A,S,K,z,s,k,k')
kmat      = permute(repmat(k0,[1 Anum Snum Knum znum snum knum]),[2 3 4 5 6 1 7]); %(A,S,K,z,s,k,k') k along dim 6
kprimemat = permute(repmat(k0,[1 Anum Snum Knum znum snum knum]),[2 3 4 5 6 7 1]);%(A,S,K,z,s,k,k') k' along dim 7
% investment i = k' - (1-delta)*k; values within ZeroBoundary of zero are snapped
% to exactly 0 so the (imat~=0) fixed-cost indicator treats inaction exactly
imattemp  = kprimemat - (1-delta)*kmat;
imattemp(abs(imattemp)<=ZeroBoundary) =0;
imat       = imattemp;

% same investment calculation on the (k',z,s,k,p) grid
ivaltemp = kprime1 - (1-delta)*k1;                              %(k',z,s,k,p)
ivaltemp(abs(ivaltemp)<=ZeroBoundary)=0; 
ival     = ivaltemp;  
          
% Transition arrays replicated along the price grid (last dim = p)
PiAA        = repmat(PiA,[1 1 1 pnum]);      %(A,A',S,p)
PiSS        = repmat(PiSigma,[1 1 pnum]);    %(S,S',p)
Pizz        = reshape(Piz,[znum znum*snum]); %(z,z'*s)

% Value-function guess (it is re-initialized inside the fcst-rule loop)
Vold        = ones(Anum,Snum,Knum,znum,snum,knum); %(A,S,K,z,s,k)
V           = Vold;

% Per-state forecast diagnostics: one (A,S) page per fcst iteration
pRMSEstore  = zeros(Anum,Snum,maxfcstit);
KRMSEstore  = zeros(Anum,Snum,maxfcstit);
pR2store    = zeros(Anum,Snum,maxfcstit);
KR2store    = zeros(Anum,Snum,maxfcstit);

% Scalar convergence statistics: one entry per fcst iteration
KRMSEchangestore       = zeros(maxfcstit,1);
pRMSEchangestore       = zeros(maxfcstit,1);
KR2changestore         = zeros(maxfcstit,1);
pR2changestore         = zeros(maxfcstit,1);
KmaxDenHaanstat        = zeros(maxfcstit,1);
pmaxDenHaanstat        = zeros(maxfcstit,1);
KavgDenHaanstat        = zeros(maxfcstit,1);
pavgDenHaanstat        = zeros(maxfcstit,1);
KmaxDenHaanchangestore = zeros(maxfcstit,1);
pmaxDenHaanchangestore = zeros(maxfcstit,1);
KavgDenHaanchangestore = zeros(maxfcstit,1);
pavgDenHaanchangestore = zeros(maxfcstit,1);

%%
for fcstit=1:maxfcstit

    disp(' ')
    disp(['Doing fcst rule iteration ' num2str(fcstit) '.'])
    disp(' ')

    %initialize the VF and policies
    if (fcstit==1)  
        %in this case, initialize using a stupid guess
        Vold          = ones(Anum,Snum,Knum,znum,snum,knum);%(A,S,K,z,s,k)
        V             = Vold;
        V_temp        = V;
        kprimeindold = kinit*ones(Anum,Snum,Knum,znum,snum,knum);    %(A,S,K,z,s,k) initialize the index of policies
    elseif (fcstit>1)
        %in this case, initialize using last GE iteration
        Vold = V;
        kprimeindold = kprimeind;
        V_temp        = V;
    end     
    
    %pre compute fcst matrices based on thetaKold and thetapold
    thetaKoldmat = exp(repmat(thetaKold(:,:,1),[1 1 Knum]) + repmat(thetaKold(:,:,2),[1 1 Knum]).*permute(repmat(log(K0),[1 Anum Snum]),[2 3 1])); %(A,S,K) contains fcst K'(A,S,K)
    thetaKoldmat(thetaKoldmat>Kmax) = Kmax-0.001;
    thetaKoldmat(thetaKoldmat<Kmin) = Kmin+0.001;
    
    thetapoldmat = exp(repmat(thetapold(:,:,1),[1 1 Knum]) + repmat(thetapold(:,:,2),[1 1 Knum]).*permute(repmat(log(K0),[1 Anum Snum]),[2 3 1])); %(A,S,K) contains fcst p'(A,S, K)
    thetapoldmat(thetapoldmat>pub) = pub;
    thetapoldmat(thetapoldmat<plb) = plb;
        
    %set up return matrix
    
      thetapmat = repmat(thetapoldmat,[1 1 1 znum snum knum knum]);     %(A,S,K,z,s,k,k')
      ymat      = Amat.*zmat.*(kmat.^alpha);                            %(A,S,K,z,s,k,k')
      ACmat     = c_k*(imat~=0).*ymat;                                  %(c_k/2) * (( ival/kval)^2)*kval;
      Rmat      = thetapmat.*(ymat - imat - ACmat);                          %(A,S,K,z,s,k,k')
      Rmat2     = permute(reshape(Rmat,[Anum*Knum*Snum,znum*snum*knum,knum]),[2 3 1]); %(z*k*s,k',A*K*S)
      Kprimewgtmat = zeros(1,Anum*Knum*Snum);
      disp('Done with return matrix setup.')
       
    %do VF iteration 
    for vfit=1:maxvfit    
            
            % Bracket the forecast K'(A,S,K) on the macro grid K0: for each
            % aggregate state, Kprimeindmat is the lower grid index and
            % Kprimewgtmat the linear weight on the grid point above it.
            Kprimevalvec  = reshape(thetaKoldmat,[1 Anum*Snum*Knum]);   %(1,A*S*K)
            Kprimevalmat  = repmat(Kprimevalvec,[Knum 1]);               %(K',A*S*K)
            K0mat         = repmat(K0,[1 Anum*Knum*Snum]);               %(K',A*S*K)
            Kprimeindmat  = sum(Kprimevalmat>=K0mat,1);                  %(1,A*S*K)
            % Guard against off-grid values element by element: remember which
            % forecasts lie above/below the grid, clamp every bracket to
            % [1,Knum-1], compute interior weights for all states, then put
            % all weight on the nearest endpoint for the off-grid ones.
            aboveKgrid    = (Kprimeindmat>=Knum);
            belowKgrid    = (Kprimeindmat<=0);
            Kprimeindmat  = min(max(Kprimeindmat,1),Knum-1);
            Klo           = reshape(K0(Kprimeindmat),  [1 Anum*Knum*Snum]);
            Khi           = reshape(K0(Kprimeindmat+1),[1 Anum*Knum*Snum]);
            Kprimewgtmat  = (Kprimevalvec - Klo)./(Khi - Klo);           %(1,A*S*K)
            Kprimewgtmat(aboveKgrid) = 1.0;
            Kprimewgtmat(belowKgrid) = 0.0;
           
         % applying Howard (policy iteration)
            for howct=1:howardnum

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
                EVMAT_H  = zeros(knum,Anum,Snum,Knum,znum,snum); %(k',A,S,K,z,s)
                Kprimeindmat2 = reshape(Kprimeindmat,[Anum Snum Knum]);
                Kprimewgtmat2 = reshape(Kprimewgtmat,[Anum Snum Knum]);

                for Act = 1:Anum
                    for Sct = 1:Snum
                            for Kct = 1:Knum
                                for zct = 1:znum 
                                    for sct =1:snum   
                                        temp = squeeze((1-Kprimewgtmat2(Act,Sct,Kct))*Vold(:,:, Kprimeindmat2(Act,Sct,Kct),:,:,:)...
                                              + Kprimewgtmat2(Act,Sct,Kct)*Vold(:,:,Kprimeindmat2(Act,Sct,Kct)+1,:,:,:));  %(A',S',z',s',k')  

                                        temp1 =  PiSigma(Sct,:)*reshape(PiA(Act,:,Sct)...
                                                                *reshape(temp,[Anum,Snum*znum*snum*knum]),[Snum,znum*snum*knum]);  

                                        EVMAT_H(:,Act,Sct,Kct,zct,sct) = Pisigma(sct,:)*reshape(Piz(zct,:,sct)*reshape(temp1,[znum,snum*knum]),[snum,knum]); 
                                                    %(z',s',k') to (z',s'*k') to (1,s',k') to    %(k',1)
                                    end
                            end
                        end
                    end
                end
                     
                
                EVMAT_H2 = permute(reshape(repmat(EVMAT_H,[1 1 1 1 1 1 knum]),[knum Aggnum znum*snum*knum]),[3 1 2]); %(zks,k',K*A*S)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
                RHSMAT       = Rmat2 + beta*EVMAT_H2;                       %(zks,k',K*A*S)
                RHSMAT2      = reshape(permute(RHSMAT,[1 3 2]),[statenum,knum]);%(z*s*k*A*K*S,k')
                kprimeinddum = reshape(permute(kprimeindold,[4 5 6 1 2 3]),[statenum 1]); %(z*s*k*A*K*S,1)
                kprime_ind   = sub2ind(size(RHSMAT2),(1:statenum)',kprimeinddum);%(z*s*k*A*K*S,1)
                Vtemp        = RHSMAT2(kprime_ind); %(z*s*k*A*S*K,1)
                V            = permute(reshape(Vtemp,[znum snum knum Anum Snum Knum ]),[4 5 6 1 2 3]);   %(z,s,k,A,S,K) to (A,S,K,z,s,k)        
                Vold = V;
         
            end
            
        % compute the expected firm value and find the optimal policy after
        % the Howard steps (vectorized computation of the conditional expectation)
                EVMAT_H  = zeros(knum,Anum,Snum,Knum,znum,snum); %(k',A,S,K,z,s); size uses snum, not the leftover loop index sct
                Kprimeindmat2 = reshape(Kprimeindmat,[Anum Snum Knum]);
                Kprimewgtmat2 = reshape(Kprimewgtmat,[Anum Snum Knum]);

                for Act = 1:Anum
                    for Sct = 1:Snum
                            for Kct = 1:Knum
                                for zct = 1:znum 
                                    for sct =1:snum   
                                        temp = squeeze((1-Kprimewgtmat2(Act,Sct,Kct))*Vold(:,:, Kprimeindmat2(Act,Sct,Kct),:,:,:)...
                                              + Kprimewgtmat2(Act,Sct,Kct)*Vold(:,:,Kprimeindmat2(Act,Sct,Kct)+1,:,:,:));  %(A',S',z',s',k')  

                                        temp1 =  PiSigma(Sct,:)*reshape(PiA(Act,:,Sct)...
                                                                *reshape(temp,[Anum,Snum*znum*snum*knum]),[Snum,znum*snum*knum]);  %(1,S',z',s',k') to %(S',z'*s'*k') to (z',s',k')

                                        EVMAT_H(:,Act,Sct,Kct,zct,sct) = Pisigma(sct,:)*reshape(Piz(zct,:,sct)*reshape(temp1,[znum,snum*knum]),[snum,knum]); 
                                                    %(z',s',k') to (z',s'*k') to (1,s',k') to    %(k',1)
                                    end
                            end
                        end
                    end
                end
                     
                
          EVMAT_H3 = permute(reshape(repmat(EVMAT_H,[1 1 1 1 1 1 knum]),[knum Aggnum znum*snum*knum]),[3 1 2]); %(zks,k',K*A*S)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        RHSMAT       = Rmat2 + beta*EVMAT_H3;                       %(zks,k',K*A*S)
        
        [Vdum,kprimeinddum] = max(RHSMAT,[],2);
        V         = permute(reshape(squeeze(Vdum),znum,snum,knum,Anum,Snum,Knum),[4 5 6 1 2 3]);
        kprimeind = permute(reshape(squeeze(kprimeinddum),znum,snum,knum,Anum,Snum,Knum),[4 5 6 1 2 3]);               
        kprime    = k0(kprimeind);
         
        %evaluate error of VF and policies
        vferr = max(abs(V(:)-Vold(:)));
        polerr = max(abs(kprime(:)-kprimeold(:)));
     
        %display diagnostics periodically
        if (mod(vfit,10)==1)
            disp(['On VF iter ' num2str(vfit) ' |V-Vold| = ' num2str(vferr) ' |k - kold| = ' num2str(polerr)])   
        end
        
        %exit VF loop if converged
        if (vferr<maxvferr) 
            break
        end
        
        %if haven't converged, update and move on
        kprimeold = kprime;
        kprimeindold = kprimeind;
        Vold = V;
    
    end
    
    disp('Done with VFI.')
    disp(' ')
    toc;
    
    I = kprime -(1-delta)*permute(repmat(k0,[1 Anum Snum Knum znum snum]),[ 2 3 4 5 6 1]); 
    I(abs(I)<=ZeroBoundary) =0;

  %%   
    %simulate the model, with market clearing loop
    
    %initialize macro capital/price & micro dist
    Ksim     = zeros(Ttot,1);
    Ksim(1)  = K0(Kinit);
    psim     = zeros(Ttot,1);
    Csim     = zeros(Ttot,1);
    Isim     = zeros(Ttot,1);
    ACsim    = zeros(Ttot,1);
    Ysim     = zeros(Ttot,1);
    Kfcstsim = zeros(Ttot,1);
    pfcstsim = zeros(Ttot,1);
    perrsim  = zeros(Ttot,1);
    Nsim     = zeros(Ttot,1);

     %initialize the cross-sectional distribution over (z,s,k) for the simulation
    if (fcstit==1)
        %first fcst iteration: all mass at productivity index zinit,
        %spread uniformly over the s and k grids
        distold            = zeros(znum,snum,knum);
        distold(zinit,:,:) = 1;
        distold            = distold/sum(distold(:));
    end
    %later fcst iterations start from distold saved by the previous simulation.
    %distzsk must be rebuilt from zeros every iteration: slice t+1 is filled by
    %accumulation (distzsk(:,:,polstar,t+1) = distzsk(:,:,polstar,t+1) + ...),
    %so mass left over from the previous fcst iteration would otherwise leak in
    distzsk          = zeros(znum,snum,knum,Ttot);
    distzsk(:,:,:,1) = distold;
    
    %actually conduct the simulation
    disp('Simulating the model.')

    for t=1:(Ttot-1)
        %extract macro prod
        Act       = Asimpos(t);
        Aval      = A0(Act);
        Sct       = sigmaAsimpos(t);
        sigmaAval = sigmaA(Sct);
        %extract macro capital and forecasts
        Kval      = Ksim(t);
        Kprimeval = exp(thetaKold(Act,Sct,1) + thetaKold(Act,Sct,2)*log(Kval));
        Kprimeind = sum(Kprimeval>=K0);
        
        % guard against off grid point values
        if (Kprimeind==Knum) 
            Kprimeind = Knum-1;
            Kprimewgt = 1.0;
        elseif (Kprimeind==0)
            Kprimeind = 1;
            Kprimewgt = 0.0;
        else
           Kprimewgt  = (Kprimeval-K0(Kprimeind))/(K0(Kprimeind+1)-K0(Kprimeind));
        end 
        Kfcstsim(t+1) = Kprimeval;
        
        pfcstval    = exp(thetapold(Act,Sct,1) + thetapold(Act,Sct,2)*log(Kval));
        pfcstsim(t) = pfcstval;
        
        %initialize price iteration
            EVMAT  = zeros(knum,znum,snum);
            temp   = squeeze((1.0-Kprimewgt)*squeeze(V(:,:,Kprimeind,:,:,:))...
                        + Kprimewgt*squeeze(V(:,:,Kprimeind+1,:,:,:)));             %(A',S',z',s',k')                                
            temp1  = PiSigma(Sct,:)*reshape(PiA(Act,:,Sct)*reshape(temp,[Anum,Snum*znum*snum*knum]),...
                    [Snum,znum*snum*knum]);  %(1,S',z',s',k') to %(S',z'*s'*k') to (z',s',k')
                    
            for zct = 1:znum 
                for sct =1:snum   
                     EVMAT(:,zct,sct) = Pisigma(sct,:)*reshape(Piz(zct,:,sct)*reshape(temp1,[znum,snum*knum]),[snum,knum]); 
                                %(z',s',k') to (z',s'*k') to (1,s',k') to    %(k',1)
                 end
            end
            
         disttemp = repmat(distzsk(:,:,:,t),[1 1 1 pnum]);%(z,s,k,p)       
           
     % vectorize the code
          EVMAT_H  = repmat(EVMAT,[1 1 1 knum pnum]);                     %(k',z,s,k,p)  
          yval     = Aval*k1_z_alpha;                                     %(k',z,s,k,p)        
          ACval    = c_k*(ival~=0).*yval;                                 %(k',z,s,k,p)
          divval   = (yval - ival - ACval);                               %(k',z,s,k,p)
          Rmatp    = p0mat.*divval;                                       %(k',z,s,k,p)
        %  clear lval yval ivaltemp ival ACval Rmatp
          RHSMAT2 = Rmatp + beta*EVMAT_H;
          [~,kprimeinddum] = max(RHSMAT2,[],1);
          kprimeindp = squeeze(kprimeinddum);
          kprimep    = k0(kprimeindp);
                 
           yvalp = squeeze(yval(1,:,:,:,:));
           ivalptemp = (kprimep - (1-delta)*squeeze(k1(1,:,:,:,:)));
           ivalptemp(abs(ivalptemp)<=ZeroBoundary)=0; 
           ivalp     = ivalptemp;
           acvalp    = c_k*(ivalp~=0).*yvalp;
           divvalp   = (yvalp-ivalp-acvalp) ;

           Yp0        = squeeze(sum(sum(sum(yvalp.*disttemp))));
           Ip0        = squeeze(sum(sum(sum(ivalp.*disttemp))));
           ACp0       = squeeze(sum(sum(sum(acvalp.*disttemp))));
           Kbarprimep0= squeeze(sum(sum(sum(kprimep.*disttemp))));
           polmatp_interp = kprimeindp; %(z,s,k,n,p)
           Cvalp = Yp0-Ip0-ACp0;
           Cp0   = Cvalp;
           ep0   = 1./p0'-Cvalp;

      %set up the boundaries of the bisection
        pvala = plb;
        pvalb = pub;
        pvalc = pvala + (0.67)*(pvalb-pvala);

        %iterate over the market-clearing value of p, for each value
        %redoing the optimization of the RHS of the Bellman equation
        %and recomputing policies
        for piter=1:maxpit
            %initialize the price-dependent policies to 0
    %         kprimeinddum(:,:) = 0; kprimeindp(:,:) = 0; 

            %brent setup
            if (piter==1); pval = pvala;end
            if (piter==2); pval = pvalb;end
            if (piter==3); pval = pvalc;end

            %brent restart
            if (piter==pcutoff)
                pvala = pfcstsim(t)-4.0*pwindow;
                pval = pvala;
            elseif (piter==pcutoff+1) 
                pvalb =  pfcstsim(t)+4.0*pwindow;
                pval = pvalb;
            elseif (piter==pcutoff+2)
                pvalc = pvala + (0.67) * (pvalb-pvala); 
                pval = pvalc;
            end 

            %if not restarting or initializing
            if ((piter>3 && piter<pcutoff) || (piter>pcutoff+2)) 
                %first, try inverse quadratic interpolation of the excess demand function
                pval = ( pvala * fb * fc ) / ( (fa - fb) * (fa - fc) ) ...
                    + ( pvalb * fa * fc ) / ( (fb - fa) * (fb - fc ) ) ...
                    + ( pvalc * fa * fb ) / ( (fc - fa) * (fc - fb ) );
                %if it lies within bounds, and isn't too close to the bounds, then done
                %o/w, take bisection step
                if (min(abs(pvala - pval),abs(pvalb-pval))<abs((pvalb-pvala)/9)) || (pval<pvala) || (pval>pvalb)
    %                 ((minval( (/ abs(pvala - pval), abs(pvalb-pval) /) )<&
    %                     abs( (pvalb-pvala)/dble(9.0) ) ).or.(pval<pvala).or.(pval>pvalb))   then
                    pval = (pvala + pvalb) /2;
                end
            end

            %actually evaluate consumption approximation function
            pval = min(max(p0(1),pval),p0(pnum));         %minval( (/ maxval( (/ p0(1) , pval /) ) , p0(pnum) /) )   
            pind = sum(pval>=p0);%pind = pnum/2;  call hunt(p0,pnum,pval,pind)
    %       pwgt = (pval - p0(pind)) / (p0(pind + 1) - p0(pind));

            % guard against off grid point values
            if (pind==pnum) 
                pind = pnum-1;
                pwgt = 1.0;
            elseif (pind==0)
                pind = 1;
                pwgt = 0.0;
            else
                pwgt = (pval - p0(pind)) / (p0(pind + 1) - p0(pind));
            end 

            %actually perform linear interpolation of the consumption function
            CvalpNew = Cp0(pind)*(1-pwgt) + Cp0(pind+1)*pwgt;

            %are you initializing the brent?
            if (piter==1); fa = (1/pval) - CvalpNew;end
            if (piter==2); fb = (1/pval) - CvalpNew;end
            if (piter==3); fc = (1/pval) - CvalpNew;end

            %are you restarting the brent?
            if (piter==pcutoff);   fa = (1/pval) - CvalpNew; end
            if (piter==pcutoff+1); fb = (1/pval) - CvalpNew; end
            if (piter==pcutoff+2); fc = (1/pval) - CvalpNew; end

            %what is the error given by this implied consumption?
            perror = (1/pval) - CvalpNew;

            %if not restarting or initializing
            if ((piter>3 && piter<pcutoff)||(piter>pcutoff+2))  
                if (perror<0) 
                    pvalc = pvalb; fc = fb;
                    pvalb = pval; fb = perror;
                    %pval a doesn't change
                elseif (perror>=0) 
                    pvalc = pvala; fc = fa;
                    pvala = pval; fa = perror;
                    %pval b doesn't change
                end
            end

            %exit criterion for market-clearing
            perror = log(pval*CvalpNew);
            if (abs(perror)<perrortol && piter>2) 
                break
            end  %piter


        end

        %insert market-clearing price and other linearly interpolated stuff into sim series
        psim(t) = pval; %this is the most recently run p from the clearing algorithm
        Csim(t) = CvalpNew; %this is already the linearly interpolated consumption C
        Ysim(t) = Yp0(pind)*(1.0-pwgt) + Yp0(pind+1)*pwgt; %linearly interpolated output Y
        Isim(t) = Ip0(pind)*(1.0-pwgt) + Ip0(pind+1)*pwgt; %linearly interpolated investment I
        ACsim(t) = ACp0(pind)*(1.0-pwgt) + ACp0(pind+1)*pwgt; %linearly interpolated ACk 
        Ksim(t+1) = Kbarprimep0(pind)*(1.0-pwgt) + Kbarprimep0(pind+1)*pwgt; %linearly interpolated K' 

        %now that the market-clearing price is determined, move on to insert weight into the next period, according to the 
        %linearly interpolated rule

          %output price stats on certain periods
          if (mod(t,50)==1)
                disp(['t = ',num2str(t), ' iter ',num2str(piter),' p ',num2str(pval),' err ',num2str(perror), ' A ',num2str(A0(Act)),' Kprime ' ,num2str(Ksim(t+1))])        
          end

    
       for zct=1:znum
           for sct =1:snum
            for kct=1:knum %endogct=1,numendog
                if (distzsk(zct,sct,kct,t)>disttol)       

                %based on the latest price, what is the policy here at pind?
                polstar = polmatp_interp(zct,sct,kct,pind);

                %insert distributional weight in appropriate slots next period
                distzsk(:,:,polstar,t+1) = distzsk(:,:,polstar,t+1) + ...
                    repmat(Piz(zct,:,sct)',1,snum).*repmat(Pisigma(sct,:),znum,1)*distzsk(zct,sct,kct,t) * (1.0-pwgt);

                %based on the latest price, what is the policy here at pind+1?
                polstar = polmatp_interp(zct,sct,kct,pind+1);

                %insert distributional weight in appropriate slots next period
                distzsk(:,:,polstar,t+1) = distzsk(:,:,polstar,t+1) + ...
                    repmat(Piz(zct,:,sct)',1,snum).*repmat(Pisigma(sct,:),znum,1)*distzsk(zct,sct,kct,t) * pwgt;


                end 
            end  %kct
           end %sct
       end  %zct

    %     %now, round to make sure that you're ending up with a distribution which makes sense each period
         distzsk(:,:,:,t+1) = distzsk(:,:,:,t+1)./sum(sum(sum(distzsk(:,:,:,t+1))));

        end

        distold = distzsk(:,:,:,end-200);

       if (checkbounds==1) 

            TOPksim = squeeze(sum(sum(distzsk(:,:,knum,:))));
            BOTksim = squeeze(sum(sum(distzsk(:,:,1,:))));

            TOPkmax = max(TOPksim((Terg+1):(Ttot-1)));
            BOTkmax = max(BOTksim((Terg+1):(Ttot-1)));

            disp(' ')
            disp('Checking for bounds of state space in simulation.');
            disp(['Top k = ' num2str(TOPkmax),'Bottom k = ' num2str(BOTkmax) ])    

        end
    
%%       
    disp('Done with simulation.')
    toc;
        
    %update the forecast rules, test for convergence
    Kest = Ksim((Terg):(Ttot-2));
    pest = psim((Terg):(Ttot-2));
    Kprimeest = Ksim((Terg+1):(Ttot-1));
    Asimposest = Asimpos((Terg):(Ttot-2));
    sigmaAsimposest = sigmaAsimpos((Terg):(Ttot-2));
    
    %number of estimation-sample observations in each aggregate state (A,S)
    Aobsct = zeros(Anum,Snum);
    KPred  = zeros(length(Kprimeest),1);
    pPred  = zeros(length(pest),1);
    
    %re-estimate the log-linear fcst rules state by state: OLS of log K' and
    %log p on a constant and log K; KPred/pPred collect the fitted log values
    for Act=1:Anum
        for Sct = 1:Snum
           samp = (Asimposest==Act & sigmaAsimposest==Sct);
           Aobsct(Act,Sct) = sum(samp);
           logKsamp = log(Kest(samp));

           if Aobsct(Act,Sct) < 2
               %too few observations to identify intercept and slope: keep
               %the previous rule for this state instead of storing the
               %NaN/Inf coefficients that a singular X'*X would produce
               warning('ks_het_firms:fcstRuleSample', ...
                   'Only %d obs for A=%d, S=%d; keeping previous fcst rule.', ...
                   Aobsct(Act,Sct), Act, Sct);
               thetaK(Act,Sct,:) = thetaKold(Act,Sct,:);
               thetap(Act,Sct,:) = thetapold(Act,Sct,:);
               KPred(samp) = thetaKold(Act,Sct,1) + thetaKold(Act,Sct,2)*logKsamp;
               pPred(samp) = thetapold(Act,Sct,1) + thetapold(Act,Sct,2)*logKsamp;
               continue
           end

           X = [ones(Aobsct(Act,Sct),1) logKsamp];

           betaOLS = (X'*X)\(X'*log(Kprimeest(samp)));
           thetaK(Act,Sct,:) = betaOLS';
           KPred(samp) = X*betaOLS;

           betaOLS = (X'*X)\(X'*log(pest(samp)));
           thetap(Act,Sct,:) = betaOLS';
           pPred(samp) = X*betaOLS;
        end
        
    end
    
    [~, ~, ~, ~, e] = regress(KPred,[ones(length(KPred),1) log(Kprimeest)]); R2K(fcstit) = e(1);
    [~, ~, ~, ~, e] = regress(pPred,[ones(length(pPred),1) log(pest)]);      R2p(fcstit) = e(1);
    
    thetaKerr = max(abs(thetaK(:)-thetaKold(:)));
    thetaperr = max(abs(thetap(:)-thetapold(:)));
    fcsterr   = max([thetaKerr,thetaperr]);
    %insert errors into storage
    kfcsterrorstore(fcstit) = thetaKerr;
    pfcsterrorstore(fcstit) = thetaperr;
    fcsterrstore(fcstit)    = fcsterr;
  
    disp (' ')
    disp('Fcst rules & new estimates')
    disp(['Max abs err ' num2str(fcsterr)])
    disp(' ')
    disp('Capital: Old, New')
    disp(num2str([thetaKold thetaK]))
    disp('Capital: Err')
    disp(num2str([thetaK-thetaKold]))
    disp(' ')
    disp('Price: Old, New')
    disp(num2str([thetapold thetap]))
    disp('Price: Err')
    disp(num2str([thetap-thetapold]))

    disp('R2K R2p')
    disp(num2str([R2K(fcstit),R2p(fcstit)]))
   
    % For each aggregate state (Act,Sct), over t = Terg+1..Ttot-1:
    %   mean of log p / log K, RMSE of log(psim) vs log(pfcstsim) and of
    %   log(Ksim) vs log(Kfcstsim), and the implied R^2 = 1 - SSE/SST.
    pmean = zeros(Anum,Snum);
    pSSE  = zeros(Anum,Snum);
    Kmean = zeros(Anum,Snum);
    KSSE  = zeros(Anum,Snum);
    pSST  = zeros(Anum,Snum);
    KSST  = zeros(Anum,Snum);

    for Act = 1:Anum
        for Sct = 1:Snum

            % pass 1: sums of log levels and squared forecast errors
            perct = 0;
            for t = (Terg+1):(Ttot-1)
                if Asimpos(t) ~= Act || sigmaAsimpos(t) ~= Sct
                    continue
                end
                perct = perct + 1;
                pmean(Act,Sct) = pmean(Act,Sct) + log(psim(t));
                pSSE(Act,Sct)  = pSSE(Act,Sct) + ( log(psim(t)) - log(pfcstsim(t)) )^2;
                Kmean(Act,Sct) = Kmean(Act,Sct) + log(Ksim(t));
                KSSE(Act,Sct)  = KSSE(Act,Sct) + ( log(Ksim(t)) - log(Kfcstsim(t)) )^2;
            end

            % turn sums into means and RMSEs
            pmean(Act,Sct) = pmean(Act,Sct)/perct;
            Kmean(Act,Sct) = Kmean(Act,Sct)/perct;
            pRMSEstore(Act,Sct,fcstit) = sqrt( pSSE(Act,Sct) / perct);
            KRMSEstore(Act,Sct,fcstit) = sqrt( KSSE(Act,Sct) / perct);

            % pass 2: total sums of squares around the state means
            for t = (Terg+1):(Ttot-1)
                if Asimpos(t) ~= Act || sigmaAsimpos(t) ~= Sct
                    continue
                end
                pSST(Act,Sct) = pSST(Act,Sct) + (log(psim(t)) - pmean(Act,Sct))^2;
                KSST(Act,Sct) = KSST(Act,Sct) + (log(Ksim(t)) - Kmean(Act,Sct))^2;
            end

            pR2store(Act,Sct,fcstit) = 1 - ( pSSE(Act,Sct) / pSST(Act,Sct) );
            KR2store(Act,Sct,fcstit) = 1- ( KSSE(Act,Sct) / KSST(Act,Sct) );

        end %Sct
    end %Act

      % Convergence metrics for RMSE and R^2 across aggregate states:
      % after the first iteration, the largest absolute change versus the
      % previous iteration; on the first iteration, the largest RMSE level
      % and the largest distance of R^2 from 1.
        if (fcstit ~= 1)
            KRMSEchange = max(reshape(abs(KRMSEstore(:,:,fcstit)-KRMSEstore(:,:,fcstit-1)),[],1));
            pRMSEchange = max(reshape(abs(pRMSEstore(:,:,fcstit)-pRMSEstore(:,:,fcstit-1)),[],1));
            KR2change   = max(reshape(abs(KR2store(:,:,fcstit)-KR2store(:,:,fcstit-1)),[],1));
            pR2change   = max(reshape(abs(pR2store(:,:,fcstit)-pR2store(:,:,fcstit-1)),[],1));
        else
            KRMSEchange = max(reshape(abs(KRMSEstore(:,:,fcstit)),[],1));
            pRMSEchange = max(reshape(abs(pRMSEstore(:,:,fcstit)),[],1));
            KR2change   = max(reshape(abs(1 - KR2store(:,:,fcstit)),[],1));
            pR2change   = max(reshape(abs(1 - pR2store(:,:,fcstit)),[],1));
        end

        % keep the per-iteration history
        KRMSEchangestore(fcstit) = KRMSEchange;
        pRMSEchangestore(fcstit) = pRMSEchange;
        KR2changestore(fcstit)   = KR2change;
        pR2changestore(fcstit)   = pR2change;
  
    
    %%%Compute the Den Haan fcst series
    % Starting from the simulated logs in period Terg, iterate the estimated
    % rules forward on their own forecasts (capital is never reset to the
    % simulated path), then record the max and mean absolute log deviation
    % from the simulation over t = Terg+1..Ttot-1.

        KDenHaanfcst = zeros(Ttot,1);
        pDenHaanfcst = zeros(Ttot,1);

        % seed the recursion with simulated values
        KDenHaanfcst(Terg) = log(Ksim(Terg));
        pDenHaanfcst(Terg) = log(psim(Terg));

        KmaxDenHaanstat(fcstit) = 0;
        pmaxDenHaanstat(fcstit) = 0;
        KavgDenHaanstat(fcstit) = 0;
        pavgDenHaanstat(fcstit) = 0;

        for t = (Terg+1):(Ttot-1)

            % capital rule: indexed by the aggregate state at t-1, when the
            % forecast of K(t) is made
            Act = Asimpos(t-1);
            Sct = sigmaAsimpos(t-1);
            KDenHaanfcst(t) = thetaK(Act,Sct,1) + thetaK(Act,Sct,2)*KDenHaanfcst(t-1);

            % price rule: indexed by the current aggregate state, evaluated
            % at the capital forecast
            Act = Asimpos(t);
            Sct = sigmaAsimpos(t);
            pDenHaanfcst(t) = thetap(Act,Sct,1) + thetap(Act,Sct,2)*KDenHaanfcst(t);

            % running max and running sum of absolute log deviations
            KmaxDenHaanstat(fcstit) = max(KmaxDenHaanstat(fcstit), abs(log(Ksim(t)) - KDenHaanfcst(t)));
            pmaxDenHaanstat(fcstit) = max(pmaxDenHaanstat(fcstit), abs(log(psim(t)) - pDenHaanfcst(t)));
            KavgDenHaanstat(fcstit) = KavgDenHaanstat(fcstit) + abs(log(Ksim(t)) - KDenHaanfcst(t));
            pavgDenHaanstat(fcstit) = pavgDenHaanstat(fcstit) + abs(log(psim(t)) - pDenHaanfcst(t));

        end  %t

        % sums -> means over the Ttot-Terg-1 periods in the window
        KavgDenHaanstat(fcstit) = KavgDenHaanstat(fcstit) / (Ttot-Terg-1);
        pavgDenHaanstat(fcstit) = pavgDenHaanstat(fcstit) / (Ttot-Terg-1);

        % change versus previous iteration (level on the first iteration)
        if (fcstit > 1)
            KmaxDenHaanchange = abs(KmaxDenHaanstat(fcstit)-KmaxDenHaanstat(fcstit-1));
            pmaxDenHaanchange = abs(pmaxDenHaanstat(fcstit)-pmaxDenHaanstat(fcstit-1));
            KavgDenHaanchange = abs(KavgDenHaanstat(fcstit)-KavgDenHaanstat(fcstit-1));
            pavgDenHaanchange = abs(pavgDenHaanstat(fcstit)-pavgDenHaanstat(fcstit-1));
        elseif (fcstit == 1)
            KmaxDenHaanchange = KmaxDenHaanstat(fcstit);
            pmaxDenHaanchange = pmaxDenHaanstat(fcstit);
            KavgDenHaanchange = KavgDenHaanstat(fcstit);
            pavgDenHaanchange = pavgDenHaanstat(fcstit);
        end

        KmaxDenHaanchangestore(fcstit) = KmaxDenHaanchange;
        pmaxDenHaanchangestore(fcstit) = pmaxDenHaanchange;
        KavgDenHaanchangestore(fcstit) = KavgDenHaanchange;
        pavgDenHaanchangestore(fcstit) = pavgDenHaanchange;

        
    % Exit test for the fcstit loop. GEerrorswitch picks the criterion; both
    % the capital and the price metric must fall below the tolerance:
    %   1 = change in rule coefficients   2 = change in RMSE
    %   3 = change in R^2                 4 = change in max Den Haan stat
    %   5 = change in avg Den Haan stat
    exitflag = 0;
    switch GEerrorswitch
        case 1
            if thetaKerr < fcsterrortol && thetaperr < fcsterrortol
                exitflag = 1;
            end
        case 2
            if KRMSEchange < RMSEchangetol && pRMSEchange < RMSEchangetol
                exitflag = 1;
            end
        case 3
            if KR2change < R2changetol && pR2change < R2changetol
                exitflag = 1;
            end
        case 4
            if KmaxDenHaanchange < maxDenHaanchangetol && pmaxDenHaanchange < maxDenHaanchangetol
                exitflag = 1;
            end
        case 5
            if KavgDenHaanchange < avgDenHaanchangetol && pavgDenHaanchange < avgDenHaanchangetol
                exitflag = 1;
            end
    end

    % progress report with every metric, whichever one is used for exit
    disp(['fcst error = ',num2str(max(thetaKerr,thetaperr)), ' RMSE ',num2str(max(KRMSEchange,pRMSEchange)),...
        ' R2 change ',num2str(max(KR2change,pR2change)),' DH max ',num2str(max(KmaxDenHaanchange,pmaxDenHaanchange)),...
        ' DH avg ',num2str(max(KavgDenHaanchange,pavgDenHaanchange))])

    if exitflag == 1, break; end
  
%     if (fcsterr<fcsterrortol) 
%        break 
%     end
    
    % Damped update: move the rules in use a fraction fcstdamp of the way
    % toward the newly estimated coefficients.
    thetaKold = thetaKold + (thetaK-thetaKold)*fcstdamp;
    thetapold = thetapold + (thetap-thetapold)*fcstdamp;

%     save('Theta.mat','thetaKold','thetapold');

    % Drop intermediate arrays, then checkpoint the whole workspace so a
    % long run can be inspected or resumed after each fcst iteration.
    clear('Vold', 'V_temp', 'Rmat2', 'EVMAT_H2', 'Rmatp', 'RHSMAT2', 'lval', ...
          'kprimeinddum', 'EVMAT_H', 'Eval', 'ACval', 'acvalp', ...
          'temp', 'yvalp', 'lvalp', 'ivalptemp', 'distemp', 'kmat_ly', ...
          'RHSMAT', 'yval', 'z1', 'y1', 'tempz', 'polmatp_interp', ...
          'Ivalp', 'kprimeindp', 'ivaltemp', 'ivalp', 'hvalp', ...
          'tempzs', 'imattemp', 'EVMAT', 'Vtemp', ...
          'kprimeindold', 'disttemp', 'samp');
 % clear ACmat kprimemat Rmat ymat nmat kmat imat  zmat Amat Azmat_ly PiFmat PiSigmamat PiAmat distold  Ashocks Asim Asimpos

    save('ks_het_firms_temp.mat');

    
end %of fcst loop

 
% Inside the fcst loop, exitflag is reset to 0 each iteration and the loop
% breaks only when it is set to 1. Any other final value means the loop ran
% out of iterations without meeting the convergence criterion.
if exist('exitflag','var') && exitflag == 1
    disp('Finished fcst rule iteration. The model is solved!')
else
    disp(['WARNING: fcst rule iteration ended at fcstit = ' num2str(fcstit) ...
        ' without meeting the convergence criterion.'])
end
toc;
disp(' ')

if (diary_record)
    diary OFF
end

% Drop large intermediate arrays, then save the final workspace. Saving is
% independent of diary_record, so that turning off logging does not also
% discard the results.
clear Vold V_temp Rmat2 EVMAT_H2 Rmatp RHSMAT2 lval
clear kprimeinddum ivaltemp EVMAT_H Eval ACval ACmat  acvalp
clear Ashocks Asim Asimpos distold temp yvalp lvalp ivalptemp distemp kprimemat kmat_ly
clear RHSMAT yval z1 y1 tempz Rmat polmatp_interp
clear Ivalp kprimep kprimeindp ivaltemp ivalp hvalp distemp ymat
clear tempzs nmat  kmat imattemp imat EVMAT Vtemp
clear kprimeold kprimeindold  zmat Amat Azmat_ly disttemp samp PiFmat PiSigmamat PiAmat

save ('ks_het_firms_real_only.mat');


